
Triton MLA perf fixes #33529

Merged
MatthewBonanni merged 17 commits into vllm-project:main from koush:triton-mla-perf
Apr 2, 2026

Conversation

@koush
Contributor

@koush koush commented Feb 2, 2026

Purpose

Triton MLA performance on sm120 degrades at batch size 1 as context length increases.

This perf issue has been bugging me for a while, as it made DeepSeek and Kimi K2 unusable; when Kimi K2.5 was released, I finally got around to digging into it. I'm familiar with CUDA, but this is my first foray into Triton.

The primary issues are suboptimal KV splitting at low batch counts, which leaves SMs underutilized, and an unnecessary load of the Q vector, since c_kv contains both.
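For intuition, here is a toy sketch of the kind of adaptive KV-split heuristic involved. The function name, SM count, and constants below are entirely hypothetical; the real logic lives in vllm/v1/attention/ops/triton_decode_attention.py and differs in detail:

```python
# Hypothetical sketch: choose how many KV splits to use so that
# batch_size * num_heads * num_kv_splits roughly covers the GPU's SMs.
# All names and constants here are illustrative, not vLLM's actual code.
def pick_num_kv_splits(batch_size: int, num_heads: int, seq_len: int,
                       num_sms: int = 128, kv_block: int = 64,
                       max_splits: int = 64) -> int:
    # Independent CTAs available without splitting the KV sequence.
    base_parallelism = batch_size * num_heads
    # Aim for enough splits to keep every SM busy (ceil division)...
    wanted = max(1, -(-num_sms // base_parallelism))
    # ...but never more splits than there are KV blocks to split over.
    max_useful = max(1, -(-seq_len // kv_block))
    return min(wanted, max_useful, max_splits)

# Batch 1 with long context: split aggressively to fill the SMs.
print(pick_num_kv_splits(batch_size=1, num_heads=16, seq_len=80_000))   # 8
# Large batch: plenty of parallelism already, so no splitting needed.
print(pick_num_kv_splits(batch_size=64, num_heads=16, seq_len=80_000))  # 1
```

The key property is the one the PR description points at: at batch size 1 the split count grows to fill the device, instead of leaving most SMs idle.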

Test Plan

Test Kimi K2.5, DeepSeek V2, and Qwen 235B.

Test Result

Models load and generate correctly; Kimi K2.5 shows a notable improvement, as seen below.

Testing with -dcp 8 and 80K context:

| Concurrent Requests | OLD Peak Throughput (tok/s) | NEW Peak Throughput (tok/s) | Peak Throughput Increase (%) |
|---------------------|-----------------------------|-----------------------------|------------------------------|
| 1 (128 tokens)      | 54.00                       | 55.00                       | 1.85%                        |
| 8 (128 tokens)      | 272.00                      | 272.00                      | 0.00%                        |
| 1 (80K tokens)      | 34.00                       | 52.00                       | 52.94%                       |
| 8 (80K tokens)      | 176.00                      | 184.00                      | 4.55%                        |
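The increase column follows directly from the OLD and NEW columns; a quick sanity check in plain Python using the numbers from the table above:

```python
# Recompute the "Peak Throughput Increase (%)" column from the table above.
rows = [
    ("1 (128 tokens)", 54.0, 55.0),
    ("8 (128 tokens)", 272.0, 272.0),
    ("1 (80K tokens)", 34.0, 52.0),
    ("8 (80K tokens)", 176.0, 184.0),
]
for name, old, new in rows:
    print(f"{name}: {100 * (new - old) / old:.2f}%")
# prints 1.85%, 0.00%, 52.94%, 4.55%
```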

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@dosubot

dosubot bot commented Feb 2, 2026

Related Documentation

No published documentation to review for changes on this repository.


@mergify mergify bot added the v1 label Feb 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant performance optimizations for Triton MLA, particularly for long context lengths. The changes include enabling CUDA graph support for decode, which can reduce kernel launch overhead. A more adaptive heuristic for calculating num_kv_splits is introduced, which should improve parallelism on different hardware. The core Triton attention kernel has been optimized by improving memory access patterns for better coalescing, reordering loads to hide latency, and adding cache modifier hints for the Triton compiler. These are solid, expert-level optimizations that should lead to the claimed performance improvements. The changes look correct and well-implemented.
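For context on the num_kv_splits change: split-KV ("flash-decoding" style) decode kernels have each split compute a partial attention output plus the log-sum-exp of its scores, and a second pass merges the partials. A minimal NumPy sketch of that two-pass scheme (illustrative only; the actual Triton kernels in vllm/v1/attention/ops/triton_decode_attention.py are considerably more involved):

```python
import numpy as np

def attention_ref(q, k, v):
    # Single-pass softmax attention for one query vector (reference).
    s = k @ q
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v

def split_kv_attention(q, k, v, num_splits):
    # Pass 1: each KV split yields a normalized partial output and the
    # log-sum-exp of its scores.
    accs, lses = [], []
    for ks, vs in zip(np.array_split(k, num_splits),
                      np.array_split(v, num_splits)):
        s = ks @ q
        m = s.max()
        e = np.exp(s - m)
        accs.append((e / e.sum()) @ vs)
        lses.append(m + np.log(e.sum()))
    # Pass 2: reweight each partial by its share of the global softmax mass.
    lses = np.array(lses)
    w = np.exp(lses - lses.max())
    w /= w.sum()  # w[i] == exp(lse_i - global_lse)
    return sum(wi * acc for wi, acc in zip(w, accs))

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal((128, 8))
v = rng.standard_normal((128, 8))
assert np.allclose(split_kv_attention(q, k, v, 4), attention_ref(q, k, v))
```

Each split runs on its own SM, which is why choosing num_kv_splits well matters so much at batch size 1.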

@mergify
Contributor

mergify bot commented Feb 2, 2026

Hi @koush, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@koush
Contributor Author

koush commented Feb 2, 2026

Benchmarks on Kimi K2.5:

Summary

| Context                  | OLD (tok/s) | NEW (tok/s) | Improvement |
|--------------------------|-------------|-------------|-------------|
| Short (100 in / 400 out) | 69.14       | 72.37       | 1.05x       |
| Long (80K in / 400 out)  | 5.61        | 24.47       | 4.36x       |
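The improvement column is just the ratio of the two throughput columns (quick check, plain Python):

```python
# Verify the improvement ratios from the summary table above.
for name, old, new in [("Short", 69.14, 72.37), ("Long", 5.61, 24.47)]:
    print(f"{name}: {new / old:.2f}x")
# prints Short: 1.05x, Long: 4.36x
```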
OLD

============ Serving Benchmark Result ============
Successful requests:                     2         
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  11.57     
Total input tokens:                      200       
Total generated tokens:                  800       
Request throughput (req/s):              0.17      
Output token throughput (tok/s):         69.14     
Peak output token throughput (tok/s):    72.00     
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          86.42     
---------------Time to First Token----------------
Mean TTFT (ms):                          76.94     
Median TTFT (ms):                        76.94     
P99 TTFT (ms):                           81.87     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.31     
Median TPOT (ms):                        14.31     
P99 TPOT (ms):                           14.45     
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.31     
Median ITL (ms):                         14.15     
P99 ITL (ms):                            20.66     
==================================================



============ Serving Benchmark Result ============
Successful requests:                     2         
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  142.69    
Total input tokens:                      160000    
Total generated tokens:                  800       
Request throughput (req/s):              0.01      
Output token throughput (tok/s):         5.61      
Peak output token throughput (tok/s):    9.00      
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          1126.90   
---------------Time to First Token----------------
Mean TTFT (ms):                          25966.21  
Median TTFT (ms):                        25966.21  
P99 TTFT (ms):                           25980.01  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          113.73    
Median TPOT (ms):                        113.73    
P99 TPOT (ms):                           113.86    
---------------Inter-token Latency----------------
Mean ITL (ms):                           113.73    
Median ITL (ms):                         113.74    
P99 ITL (ms):                            114.33    
==================================================

NEW

============ Serving Benchmark Result ============
Successful requests:                     2         
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  11.06     
Total input tokens:                      200       
Total generated tokens:                  800       
Request throughput (req/s):              0.18      
Output token throughput (tok/s):         72.37     
Peak output token throughput (tok/s):    74.00     
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          90.46     
---------------Time to First Token----------------
Mean TTFT (ms):                          74.98     
Median TTFT (ms):                        74.98     
P99 TTFT (ms):                           78.82     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          13.66     
Median TPOT (ms):                        13.66     
P99 TPOT (ms):                           13.68     
---------------Inter-token Latency----------------
Mean ITL (ms):                           13.66     
Median ITL (ms):                         13.66     
P99 ITL (ms):                            14.08     
==================================================

============ Serving Benchmark Result ============
Successful requests:                     2         
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  32.69     
Total input tokens:                      160000    
Total generated tokens:                  800       
Request throughput (req/s):              0.06      
Output token throughput (tok/s):         24.47     
Peak output token throughput (tok/s):    26.00     
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          4918.99   
---------------Time to First Token----------------
Mean TTFT (ms):                          490.87    
Median TTFT (ms):                        490.87    
P99 TTFT (ms):                           491.32    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          39.73     
Median TPOT (ms):                        39.73     
P99 TPOT (ms):                           39.76     
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.73     
Median ITL (ms):                         39.72     
P99 ITL (ms):                            40.34     
==================================================


@koush koush changed the title Triton MLA perf fixes (4x improvement at 80k context) Triton MLA GQA perf fixes (4x improvement at 80k context) Feb 2, 2026
root added 5 commits February 2, 2026 02:47
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
koush added 3 commits February 2, 2026 20:02
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
koush added 2 commits February 3, 2026 06:46
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
@koush
Contributor Author

koush commented Feb 3, 2026

@tjtanaa I tested -dcp with CUDA graphs and the model worked fine, but I did notice a performance degradation compared to not using CUDA graphs. I'm not sure what the expected behavior is supposed to be.

Regardless, since DCP is the recommended way to run it (it saves KV cache space by preventing data duplication) and CUDA graphs reduced performance, I backed out that part of my change.

@mgehre-amd
Contributor

gsm8k

I assume you mean this? docs.vllm.ai/projects/ascend/en/latest/developer_guide/evaluation/using_lm_eval.html

I'm running it like:

lm_eval --model vllm \
  --model_args pretrained=<model_name>,max_model_len=4096,gpu_memory_utilization=0.8 \
  --tasks gsm8k \
  --limit 200 \
  --batch_size 1

@koush
Contributor Author

koush commented Mar 27, 2026

@mgoin

2026-03-27:18:18:44 WARNING  [config.evaluate_config:281] --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.
2026-03-27:18:18:49 INFO     [_cli.run:376] Selected Tasks: ['gsm8k']
2026-03-27:18:18:49 INFO     [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-03-27:18:18:49 INFO     [evaluator:236] Initializing local-completions model, with arguments: {'model': 'moonshotai/Kimi-K2.5', 'model_name': 'qwen3_coder', 'base_url': 'http://127.0.0.1:8000/v1/completions', 'tokenized_requests': False, 'trust_remote_code': True, 'api_key': 'token-abc123'}
2026-03-27:18:18:49 INFO     [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-03-27:18:18:49 INFO     [models.api_models:172] Using max length 2048 - 1
2026-03-27:18:18:49 INFO     [models.api_models:175] Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1.
2026-03-27:18:18:49 INFO     [models.api_models:193] Using tokenizer huggingface
2026-03-27:18:18:53 INFO     [tasks:700] Selected tasks:
2026-03-27:18:18:53 INFO     [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-03-27:18:18:53 INFO     [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-03-27:18:18:53 INFO     [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 847.50it/s]
2026-03-27:18:18:54 INFO     [evaluator:584] Running generate_until requests
2026-03-27:18:18:54 INFO     [models.api_models:733] Tokenized requests are disabled. Context + generation length is not checked.
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [07:21<00:00,  2.21s/it]
fatal: not a git repository (or any parent up to mount point /mnt)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
2026-03-27:18:26:18 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'moonshotai/Kimi-K2.5', 'base_url': 'http://127.0.0.1:8000/v1/completions', 'tokenized_requests': False}), gen_kwargs: ({}), limit: 200.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.955|±  |0.0147|
|     |       |strict-match    |     5|exact_match|↑  |0.955|±  |0.0147|

Is this what you need?

@koush koush requested a review from MatthewBonanni as a code owner March 27, 2026 19:31
Signed-off-by: Koushik Dutta <koushd@gmail.com>
@koush koush force-pushed the triton-mla-perf branch from 65fb1f7 to 19a10fc Compare March 29, 2026 05:19
@mergify
Contributor

mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @koush.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 1, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
@mergify mergify bot removed the needs-rebase label Apr 1, 2026
Collaborator

@MatthewBonanni MatthewBonanni left a comment


LGTM

@MatthewBonanni MatthewBonanni merged commit d9408ff into vllm-project:main Apr 2, 2026
59 checks passed
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Apr 3, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
askliar pushed a commit to netanel-haber/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
askliar pushed a commit to netanel-haber/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
askliar pushed a commit to netanel-haber/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
USTCKAY pushed a commit to USTCKAY/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Song Kai <songkai05@baidu.com>
rishitdholakia13 pushed a commit to rishitdholakia13/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: rishitdholakia13 <rishit+github@cohere.com>
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Rishi Puri <riship@nvidia.com>
big-yellow-duck pushed a commit to EmbeddedLLM/vllm that referenced this pull request Apr 8, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
jackcfwang pushed a commit to jackcfwang/vllm that referenced this pull request Apr 10, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: jackcfwang <jackcfwang@tencent.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 10, 2026
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Labels

deepseek: Related to DeepSeek models
performance: Performance-related issues
ready: ONLY add when PR is ready to merge/full CI is needed
v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants